{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf048c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "# ^^^ pyforest auto-imports - don't write above this line\n",
    "sys.path.insert(0, str(Path(\"../../\").resolve()))\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d14eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_reader import CINC2022Reader, CINC2016Reader, EPHNOGRAMReader\n",
    "from dataset import CinC2022Dataset\n",
    "from models import (\n",
    "    CRNN_CINC2022,\n",
    "    SEQ_LAB_NET_CINC2022,\n",
    "    UNET_CINC2022,\n",
    "    Wav2Vec2_CINC2022,\n",
    "    HFWav2Vec2_CINC2022,\n",
    ")\n",
    "from cfg import TrainCfg, ModelCfg\n",
    "from trainer import CINC2022Trainer, _MODEL_MAP, _set_task, collate_fn\n",
    "from utils.plot import plot_spectrogram\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "import torchaudio\n",
    "from copy import deepcopy\n",
    "\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "CRNN_CINC2022.__DEBUG__ = False\n",
    "Wav2Vec2_CINC2022.__DEBUG__ = False\n",
    "HFWav2Vec2_CINC2022.__DEBUG__ = False\n",
    "CinC2022Dataset.__DEBUG__ = False\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d24e6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "db_dir = \"/data1/Jupyter-Data/CinC2022/\"  # replace with the data directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be00dae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "if ModelCfg.torch_dtype == torch.float64:\n",
    "    torch.set_default_tensor_type(torch.DoubleTensor)\n",
    "    DTYPE = np.float64\n",
    "else:\n",
    "    DTYPE = np.float32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19622e5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# task = \"classification\"\n",
    "task = \"multi_task\"\n",
    "\n",
    "train_config = deepcopy(TrainCfg)\n",
    "# train_config.db_dir = data_folder\n",
    "# train_config.model_dir = model_folder\n",
    "# train_config.final_model_filename = _ModelFilename\n",
    "train_config.debug = True\n",
    "\n",
    "train_config.db_dir = db_dir\n",
    "\n",
    "# train_config.n_epochs = 100\n",
    "# train_config.batch_size = 24  # 16G (Tesla T4)\n",
    "# train_config.log_step = 20\n",
    "# # train_config.max_lr = 1.5e-3\n",
    "# train_config.early_stopping.patience = 20\n",
    "\n",
    "train_config[task].model_name = \"crnn\"  # \"wav2vec2_hf\"\n",
    "\n",
    "train_config[task].cnn_name = \"tresnetF\"  # \"resnet_nature_comm_bottle_neck_se\"\n",
    "# train_config[task].rnn_name = \"none\"  # \"none\", \"lstm\"\n",
    "# train_config[task].attn_name = \"se\"  # \"none\", \"se\", \"gc\", \"nl\"\n",
    "\n",
    "_set_task(task, train_config)\n",
    "\n",
    "model_config = deepcopy(ModelCfg[task])\n",
    "\n",
    "# adjust model choices if needed\n",
    "model_config.model_name = train_config[task].model_name\n",
    "# print(model_name)\n",
    "if \"cnn\" in model_config[model_config.model_name]:\n",
    "    model_config[model_config.model_name].cnn.name = train_config[task].cnn_name\n",
    "if \"rnn\" in model_config[model_config.model_name]:\n",
    "    model_config[model_config.model_name].rnn.name = train_config[task].rnn_name\n",
    "if \"attn\" in model_config[model_config.model_name]:\n",
    "    model_config[model_config.model_name].attn.name = train_config[task].attn_name\n",
    "\n",
    "# model_config.wav2vec2.cnn.name = \"resnet_nature_comm_bottle_neck_se\"\n",
    "# model_config.wav2vec2.encoder.name = \"wav2vec2_nano\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13aae089",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29e5957",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cls = _MODEL_MAP[model_config.model_name]\n",
    "model_cls.__DEBUG__ = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894d8c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model_cls(config=model_config)\n",
    "if torch.cuda.device_count() > 1:\n",
    "    model = DP(model)\n",
    "    # model = DDP(model)\n",
    "model.to(device=DEVICE);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2322fd35",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.module.module_size, model.module.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad39ec81",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6079dd6d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92bb0f3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3faa59f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train = CinC2022Dataset(train_config, task, training=True, lazy=True)\n",
    "ds_test = CinC2022Dataset(train_config, task, training=False, lazy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed4ce9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train._load_all_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40462919",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_test._load_all_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b6f218",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df736ca5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65f656bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = CINC2022Trainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_config,\n",
    "    device=DEVICE,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16aca01a",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer._setup_dataloaders(ds_train, ds_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d7ed365",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_state_dict = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e6184ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9945c0e9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb2b222",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "54063906",
   "metadata": {},
   "source": [
    "## Inspect trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e84a372",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import Wav2Vec2_CINC2022, CRNN_CINC2022\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e54564aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = CRNN_CINC2022.from_checkpoint(\n",
    "    \"./saved_models/BestModel_task-multi_task_CRNN_CINC2022_epoch41_08-11_02-38_metric_-16272.44.pth.tar\"\n",
    "    # replace with a saved model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40c9eba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt[0].config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "860dcc74",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_model = ckpt[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80f6ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_model = best_model.to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "150382e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = DataLoader(\n",
    "    dataset=ds_train,\n",
    "    batch_size=4,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    "    pin_memory=True,\n",
    "    drop_last=False,\n",
    "    collate_fn=collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c13ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch in dl:\n",
    "    labels = batch\n",
    "    waveforms = labels.pop(\"waveforms\")\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8eb3a45",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_model(waveforms, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec8dbcef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d68967d4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e28e78a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "664b1bce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed0ba02",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
